## This file runs LASSO-related methods on the voting example

## LASSO -----------
# Cross-validated LASSO over the main effects of X together with every
# two-way and three-way product of its columns. Interaction columns are
# named by joining the parent column names with ":" (e.g. "a:b", "a:b:c").
set.seed(123)  # makes the random CV fold assignment in cv.glmnet reproducible
X_lasso <- X
J <- ncol(X_lasso)

# create interactions: each row of `twoway` / `threeway` holds the column
# names that make up one interaction term
twoway <- gtools::combinations(n = J, r = 2, v = colnames(X_lasso))
threeway <- gtools::combinations(n = J, r = 3, v = colnames(X_lasso))

X_lasso <- local({
  pair_terms <- apply(twoway, MARGIN = 1, FUN = function(v) {
    X_lasso[, v[1]] * X_lasso[, v[2]]
  })
  triple_terms <- apply(threeway, MARGIN = 1, FUN = function(v) {
    X_lasso[, v[1]] * X_lasso[, v[2]] * X_lasso[, v[3]]
  })
  widened <- cbind(X_lasso, pair_terms, triple_terms)
  # the first J columns keep their original names; label the appended ones
  new_names <- c(
    apply(twoway, MARGIN = 1, FUN = paste, collapse = ":"),
    apply(threeway, MARGIN = 1, FUN = paste, collapse = ":")
  )
  colnames(widened)[J + seq_len(choose(J, 2) + choose(J, 3))] <- new_names
  widened
})

X_lasso <- as.matrix(X_lasso)
out_l1 <- cv.glmnet(x = X_lasso, y = Y)


## selected rules ----------
# Count the terms LASSO keeps at lambda.min (intercept excluded). Then count
# them again after adding the lower-order terms each interaction implies.
# The numbers in the "# nonzero" / "# incl lower order" comments appear to
# be outputs recorded from an earlier run.
nonzero_idx <- which(coef(out_l1, s = out_l1$lambda.min)[-1] != 0)

# nonzero: 46
print("Number of nonzero coefficients (LASSO):")
print(length(nonzero_idx))

# incl lower order: 172
# Index colnames() directly. Subsetting the matrix first would drop it to a
# plain vector (colnames() -> NULL) when exactly one term is selected.
selected <- colnames(X_lasso)[nonzero_idx]

# Hierarchy expansion. Any term contributes its main effects ("a:b" -> a, b).
# A three-way term "a:b:c" also contributes its pairs a:b, a:c and b:c, built
# in the same part order as the three-way name.
# NOTE(review): this relies on the pair names matching the "x:y" interaction
# column names created when X_lasso was built.
lower_order <- unlist(lapply(
  strsplit(selected, ":", fixed = TRUE),
  function(parts) {
    if (length(parts) == 3) {
      c(parts, paste(parts[c(1, 1, 2)], parts[c(2, 3, 3)], sep = ":"))
    } else {
      parts
    }
  }
))
selected <- unique(c(selected, lower_order))
print("Number of nonzero coefficients (LASSO) if we include lower order terms:")
print(length(selected))

## Run OLS with just variables that LASSO picks
# Post-LASSO OLS refits. The first uses only the columns with nonzero LASSO
# coefficients; the second uses the hierarchy-expanded `selected` set.
# `drop = FALSE` keeps a matrix even when a single column is chosen, so lm()
# coefficient names keep the "X_lasso_lm<column>" form.
sig_level <- 0.05

# just nonzero: 7 significant
lasso_nonzero <- which(coef(out_l1, s = out_l1$lambda.min)[-1] != 0)
X_lasso_lm <- X_lasso[, lasso_nonzero, drop = FALSE]
out_lm <- lm(Y ~ X_lasso_lm)
coef_table <- summary(out_lm)$coefficients  # compute summary() only once
pvals <- coef_table[-1, "Pr(>|t|)"]          # [-1] drops the intercept row
pe <- coef_table[-1, "Estimate"]
print("Significant (OLS) at 5% level:")
print(pe[pvals < sig_level])

# incl lower order: 13
X_lasso_lm <- X_lasso[, selected, drop = FALSE]
out_lm <- lm(Y ~ X_lasso_lm)
print("Number of significant (OLS) coefficients at 5% level if we include lower order terms:")
print(sum(summary(out_lm)$coefficients[-1, "Pr(>|t|)"] < sig_level))
